"""add prediction table

Revision ID: 85e99d360a52
Revises: 5c8cc5112b38
Create Date: 2023-03-06 11:16:54.923878

"""
import json
from typing import Any, List, Optional

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "85e99d360a52"
down_revision = "5c8cc5112b38"
branch_labels = None
depends_on = None

TaskTypeDatasetInfer = 15


def get_record(conn: Any, table: str, id_: Optional[int]) -> Any:
    """Fetch a single row from *table* by primary key.

    Returns None when *id_* is None or when no row matches. Only trusted,
    migration-internal ids are interpolated into the query here.
    """
    if id_ is None:
        return None
    matches = conn.execute(f"SELECT * FROM {table} WHERE id = {id_}").fetchall()
    # First matching row, or None when the result set is empty.
    return next(iter(matches), None)


def convert_dataset_to_prediction(conn: Any, datasets: List) -> Any:
    """Map inference-result dataset rows to prediction-table row dicts.

    Each dataset row is copied verbatim and then overlaid with values taken
    from its originating task's JSON ``parameters`` blob (dataset_id,
    model_id, model_stage_id); the dataset's hash doubles as the prediction
    name. A dataset whose task row is missing or malformed raises, which the
    caller treats as "skip the whole migration".
    """
    converted = []
    for row in datasets:
        task_row = get_record(conn, "task", row.task_id)
        task_params = json.loads(task_row.parameters)
        converted.append(
            {
                **row,
                "name": row.hash,
                "dataset_id": task_params["dataset_id"],
                "model_id": task_params["model_id"],
                "model_stage_id": task_params["model_stage_id"],
            }
        )
    return converted


def upgrade() -> None:
    """Create the ``prediction`` table, index it, and best-effort migrate
    legacy inference results (stored as datasets) into it."""
    # ### commands auto generated by Alembic - please adjust! ###
    columns = [
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("name", sa.String(length=100), nullable=False),
        sa.Column("hash", sa.String(length=100), nullable=False),
        sa.Column("source", sa.SmallInteger(), nullable=False),
        sa.Column("description", sa.String(length=100), nullable=True),
        sa.Column("result_state", sa.SmallInteger(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column("project_id", sa.Integer(), nullable=False),
        sa.Column("task_id", sa.Integer(), nullable=False),
        sa.Column("dataset_id", sa.Integer(), nullable=False),
        sa.Column("model_id", sa.Integer(), nullable=False),
        sa.Column("model_stage_id", sa.Integer(), nullable=False),
        sa.Column("asset_count", sa.Integer(), nullable=True),
        sa.Column("keyword_count", sa.Integer(), nullable=True),
        sa.Column("keywords", sa.Text(length=20000), nullable=True),
        sa.Column("is_visible", sa.Boolean(), nullable=False),
        sa.Column("is_deleted", sa.Boolean(), nullable=False),
        sa.Column("create_datetime", sa.DateTime(), nullable=False),
        sa.Column("update_datetime", sa.DateTime(), nullable=False),
    ]
    prediction_table = op.create_table(
        "prediction",
        *columns,
        sa.PrimaryKeyConstraint("id"),
    )

    # (column, unique) pairs; creation order matches the autogenerated output.
    indexed_columns = [
        ("dataset_id", False),
        ("hash", True),
        ("id", False),
        ("model_id", False),
        ("model_stage_id", False),
        ("name", False),
        ("project_id", False),
        ("result_state", False),
        ("source", False),
        ("task_id", False),
        ("user_id", False),
    ]
    with op.batch_alter_table("prediction", schema=None) as batch_op:
        for column, unique in indexed_columns:
            batch_op.create_index(batch_op.f("ix_prediction_" + column), [column], unique=unique)

    conn = op.get_bind()
    try:
        # copy inference results from dataset table to prediction table
        rows = conn.execute(f"SELECT * FROM dataset WHERE source = {TaskTypeDatasetInfer}").fetchall()
        op.bulk_insert(prediction_table, convert_dataset_to_prediction(conn, rows))
        # Soft-delete the migrated dataset rows rather than removing them.
        conn.execute(f"UPDATE dataset SET is_deleted = 1 WHERE source = {TaskTypeDatasetInfer}")
    except Exception as e:
        # Best effort: a data-migration failure must not block the schema change.
        print("Could not migrate inference results to predictions, skip: %s" % e)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``prediction`` table and all of its indexes.

    Migrated data is NOT copied back to the dataset table; the soft-deleted
    dataset rows from ``upgrade`` are left as-is.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop order is the reverse of the creation order in upgrade().
    index_columns = [
        "user_id",
        "task_id",
        "source",
        "result_state",
        "project_id",
        "name",
        "model_stage_id",
        "model_id",
        "id",
        "hash",
        "dataset_id",
    ]
    with op.batch_alter_table("prediction", schema=None) as batch_op:
        for column in index_columns:
            batch_op.drop_index(batch_op.f("ix_prediction_" + column))

    op.drop_table("prediction")
    # ### end Alembic commands ###
